#! /usr/bin/env python
"""
Acoustic Startle Program
Python version

This program generates sound in one channel (L), followed by a startle stimulus
(noise burst) in the other channel (R).
The conditioning sound can be one of:
tone pip before the startle
tone with gap before the startle
bandpass noise burst
bandpass noise burst with gap before the startle

Output hardware is either a National Instruments DAC card or a system sound card.
If the NI DAC is available, TDT system 3 hardware is assumed as well for the
attenuators (PA5) and an RP2.1 to input the startle response.
Second channel of RP2.1 is collected as well. Use this for a microphone input
to monitor sound in the chamber.

Requires: startle2.rco, generated by RPvds (from TDT)

Python 2.5
PyQt4, Qt Designer (for Gui)
scipy, pylab, numpy, matplotlib
pyaudio
stack (required: it is imported unconditionally and holds mouse-click data)

Works with Enthought distribution on Mac OS X and Windows.

November, 2008. Paul B. Manis, Ph.D.
UNC Chapel Hill
Supported by NIH Grant DC000425-22

"""
################################################################################

import sys, re, os
import datetime, time
from time import strftime
import struct, ctypes
from PyQt4 import Qt, QtCore, QtGui
from sets import *

from pylab import *
import scipy, numpy
import scipy.signal
import numpy
from numpy.fft import fft
from matplotlib.font_manager import FontProperties
import matplotlib.pyplot as plt

# non-standard stuff:
import stack
import pyaudio
from random import sample

# our gui:
from PyStartle_gui import Ui_MainWindow

################################################################################
# the first thing we must do is find out what hardware is available and what
# system we are on.
################################################################################
print "PyStartle: Checking Hardware and OS"
try:
    if os.name is not 'nt':
        assert 0 # force use of pyaudio if not on windows xp/nt.
    print "OS is Windows (NT or XP)"
    # get the drivers and the activeX control (win32com)
    from nidaq import NIDAQ as n
    import nidaq
    import win32com.client
    
    print "Attemtp to Assert num devs > 0:",
    assert(len(n.listDevices()) > 0)
    print "  OK"
    print "devices: %s" % n.listDevices()
    hwerr = 0
    print "getDevice:",
    dev0 = n.getDevice('Dev2')
    print "  ", dev0
    
    print "\nAnalog Channels:",
    # print "  AI: ", dev0.listAIChannels()
    print "  AO: ", dev0.listAOChannels() # check output only
    
    # active x connection to attenuators
    # note - variables set at this scope level are global to source file
    PA5 = win32com.client.Dispatch("PA5.x")
    a=PA5.ConnectPA5("USB", 1)
    if a > 0:
        print "Connected to PA5 Attenuator 1"
    else:
        print "Failed to connect to PA5 Attenuator 1"
        hwerr = 1
    PA5.SetAtten(120.0)
    a = PA5.ConnectPA5("USB", 2)
    if a > 0:
        print "Connected to PA5 Attenuator 2"
    else:
        print "Failed to connect to PA5 Attenuator 2"
        hwerr = 1
        
    PA5.SetAtten(120.0)
    RP21 = win32com.client.Dispatch("RPco.x") # connect to RP2.1
    a = RP21.ConnectRP2("USB", 1)
    if a > 0:
        print "RP2.1 Connect is good: %d" % (a)
    else:
        print "Failed to connect to PA5 Attenuator 1"
        hwerr = 1
    RP21.ClearCOF()
    samp_cof_flag = 2 # 2 is for 24.4 kHz
    samp_flist = [6103.5256125, 12210.703125, 24414.0625, 48828.125, 
    97656.25, 195312.5]
    if samp_cof_flag > 5:
        samp_cof_flag = 5
    a = RP21.LoadCOFsf("C:\pyStartle\startle2.rco", samp_cof_flag)
    if a > 0:
        print "Connected to TDT RP2.1 and startle2.rco is loaded"
    else:
        print "Error loading startle2.rco?, error = %d" % (a)
        hwerr = 1
    def_sampleRate = 100000
    my_hardware = 'nidaq'
    if hwerr == 1:
        print "?? Error connecting to hardware"
        exit()
        
except:
    print "OS/hardware only supports standard sound card audio"
    my_hardware = 'pyaudio'
    def_sampleRate = 44100

REF_ES_dB = 86.0; # calibration info -  Assumes 10 dB padding with attenuator.
REF_ES_volt = 2.0; # output in volts to get refdb
REF_MAG_dB = 100.0; # right speaker is mag... different scaling.

print "PyStartle is running with output hardware: %s" % (my_hardware)
    
################################################################################
# One class for the program: PyStartle
################################################################################

class PyStartle(QtGui.QMainWindow):
    
    def __init__(self):
        """Create the main window and initialize the program state.

        Sets default acquisition state, builds the Qt Designer UI
        (Ui_MainWindow), connects the buttons and menu actions to their
        slots, creates TrialTimer (whose timeout runs NextTrial), reads
        "pystartle.ini" and posts a welcome message in the status list.
        """
        # Initialize the Qt base class this window actually derives from
        # (QMainWindow), not QDialog.
        QtGui.QMainWindow.__init__(self)

        self.AutoSave = True
        self.maxptsplot = 10000
        self.SF = 44100 # for ni board or audio board output
        self.RPSF = 12000
        self.ch1 = []
        self.ch2 = []
        self.response_tb = [] # response time base
        self.stim_tb = [] # stimulation time base (not implemented yet..)
        self.PostDuration = 0.35 # seconds after startle ends to record response
        self.PPGo = False # True while a prepulse run is in progress
        # Notch defaults; readParameters overwrites them from the GUI.
        self.PP_Notch_F1 = 12000.0
        self.PP_Notch_F2 = 14000.0
        self.fileDate = ''
        self.Description = "Acoustic Startle Parameters"
        self.CurrentTab = 0 # default current tab - left most entry
        self.stack1 = stack.Stack() # mouse click data from figure 1 (onclick1)
        self.stack2 = stack.Stack() # mouse click data from figure 2 (onclick2)

        self.ui = Ui_MainWindow()
        self.ui.setupUi(self)

        # Menu entries are QActions, which emit triggered(); QAction has no
        # clicked() signal, so connecting to clicked() never fired.
        self.connect(self.ui.actionQuit, QtCore.SIGNAL("triggered()"),
                     self.slotQuit)
        self.connect(self.ui.actionOpen, QtCore.SIGNAL("triggered()"),
                     self.Analysis_Read)

        # Push buttons emit clicked().
        button_slots = (
            (self.ui.QuitButton, self.slotQuit),
            (self.ui.CloseDataWindows, self.slotCloseDataWindows),
            (self.ui.ToneTest, self.ToneTest),
            (self.ui.NoiseTest, self.NoiseTest),
            (self.ui.PrePulse_Run, self.PrePulseStart),
            (self.ui.PrePulse_Stop, self.PrePulseStop),
            (self.ui.Save_Params, self.writeini),
            (self.ui.Load_Params, self.readini),
            (self.ui.Write_Data, self.Write_Data),
            (self.ui.Analysis_Read, self.Analysis_Read),
            (self.ui.Analysis_Test, self.Analysis_Test),
            (self.ui.Analysis_Analyze, self.Analyze_Data),
        )
        for button, slot in button_slots:
            self.connect(button, QtCore.SIGNAL("clicked()"), slot)

        # One timer sequences the trials; each timeout runs NextTrial.
        self.TrialTimer = QtCore.QTimer()
        self.connect(self.TrialTimer, QtCore.SIGNAL("timeout()"), self.NextTrial)

        self.readini("pystartle.ini") # read the initialization file if it is there.
        self.setMainWindow('default')
        self.statusBar().showMessage("No File")
        self.Status('Welcome to PyStartle V0.8')

################################################################################
# utility routines for Gui:
# close the windows and exit
#
 
    def slotQuit(self):
        """Halt the RP2.1 (NI/TDT hardware only), close plot windows and quit.

        Halting the hardware is best-effort: if it fails, a message is
        printed and the application still quits.
        """
        if my_hardware == 'nidaq':
            try:
                RP21.Halt() # make sure the RP21 is stopped.
            except Exception:
                # Previously try/finally let this error abort the quit.
                print("PyStartle: RP2.1 Halt failed during quit")
        self.slotCloseDataWindows() # should close the matplotlib windows...
        QtCore.QCoreApplication.quit()

#
# just close the data plot windows (matplotlib windows)
#
    def slotCloseDataWindows(self):
        """Close matplotlib figures 1 to 4 (the data plot windows)."""
        figure_number = 1
        while figure_number < 5:
            try:
                plt.close(figure_number)
            except AttributeError:
                pass
            figure_number += 1

    def getCurrentTab(self):
        """Return the index of the selected acquisition tab.

        The index is also cached in self.CurrentTab.
        """
        index = self.ui.AcquisitionTabs.currentIndex()
        self.CurrentTab = index
        return index

    def setCurrentTab(self, tab = 0):
        """Select acquisition tab number `tab` (default: the left-most tab)."""
        tabs = self.ui.AcquisitionTabs
        tabs.setCurrentIndex(tab)
       
# update status window
#
#
    def Status(self, text, clear = 0):
        """Insert a timestamped line at the top of the status list and select it.

        `clear` is accepted for compatibility but is not used.
        """
        window = self.ui.Status_Window
        stamp = datetime.datetime.now().ctime()
        window.insertItem(0, '[' + stamp + ']  ' + text)
        newest = window.item(0)
        window.setCurrentItem(newest)
        window.update() # repaint after every line

    def setMainWindow(self, text):
        """Show `text` in brackets after the program name in the title bar."""
        title = "PyStartle [%s]" % (text)
        self.setWindowTitle(title)
        
# figure title for matplotlib window... 
    def putTitle(self, infotext):
        """Write a centred title at the top of the current matplotlib figure.

        The title shows the base name of self.fileName, the numbers in
        self.reclist and self.blocklist, and then `infotext`.
        """
        fname = os.path.split(self.fileName)[1]
        records = ''.join(['%d ' % (r) for r in self.reclist])
        blocks = ''.join(['%d ' % (b) for b in self.blocklist])
        titletext = ('File: %s  R:[' % (fname)) + records + '] B:[ ' + blocks + '] ' + infotext
        gcf().text(0.5, 0.95, titletext, horizontalalignment='center',
                   fontproperties=FontProperties(size=12))
       
       
# 
# Handle mouse events in matplotlib windows.
# onclick1 handles "figure1" events - the raw data traces.
# onclick2 handles "figure2" events - the analyzed data.

    def onclick1(self, event):
        """Record a mouse click in figure 1 (raw data traces).

        Shows the click's data coordinates on the LCD readouts, pushes the
        (x, y) pair onto self.stack1 and reports the stack size.  Clicks
        outside any axes have xdata/ydata of None; they are ignored, because
        they would otherwise push (None, None) and break the "%8.3f" format.
        """
        if event.xdata is None or event.ydata is None:
            return
        self.ui.lcdXNumber.display(event.xdata)
        self.ui.lcdYNumber.display(event.ydata)
        self.stack1.push((event.xdata, event.ydata))
        self.Status("stack1: %d items (%8.3f %8.3f)" % (self.stack1.num_items(),
                                                        event.xdata, event.ydata))
#	print '1: button=%d, x=%d, y=%d, xdata=%f, ydata=%f' % (event.button, event.x, event.y, event.xdata, event.ydata) 

    def onclick2(self, event):
        """Record a mouse click in figure 2 (analyzed data).

        Shows the click's data coordinates on the LCD readouts, pushes the
        (x, y) pair onto self.stack2 and reports the stack size.  Clicks
        outside any axes (xdata/ydata is None) are ignored, so no
        (None, None) entries reach the stack.
        """
        if event.xdata is None or event.ydata is None:
            return
        self.ui.lcdXNumber.display(event.xdata)
        self.ui.lcdYNumber.display(event.ydata)
        self.stack2.push((event.xdata, event.ydata))
        self.Status("stack2: %d items" % (self.stack2.num_items()))
#	print '2: button=%d, x=%d, y=%d, xdata=%f, ydata=%f' % (event.button, event.x, event.y, event.xdata, event.ydata) 
################################################################################


################################################################################
# Read the gui data into our local parameters
################################################################################
    def readParameters(self):
        """Copy the current GUI settings into the attributes used by a run.

        Reads the AutoSave flag, the Levels and Durations tab, the Waveforms
        tab and the Timing and Trials tab, then the Analysis tab through
        readAnalysisTab().
        """
        ui = self.ui
        self.AutoSave = ui.AutoSave.isChecked()

        # Levels and Durations tab
        self.CN_Level = ui.Condition_Level.value()
        self.CN_Dur = ui.Condition_Dur.value()
        self.CN_Var = ui.Condition_Var.value()
        self.PP_Level = ui.PrePulse_Level.value()
        self.PP_OffLevel = ui.PrePulse_Off_Level.value()
        self.PP_Dur = ui.PrePulse_Dur.value()
        self.PS_Dur = ui.PreStartle_Dur.value()
        self.ST_Dur = ui.Startle_Dur.value()
        self.ST_Level = ui.Startle_Level.value()
        self.StimEnable = ui.Stimulus_Enable.isChecked()
        self.WavePlot = ui.Waveform_PlotFlag.isChecked()
        self.ShowSpectrum = ui.OnlineSpectrum_Flag.isChecked()
        self.OnLineAnalysis = ui.OnlineAnalysis_Flag.isChecked()

        # Waveforms tab
        self.PP_Freq = ui.PrePulse_Freq.value()
        self.PP_HP = ui.PrePulse_HP.value()
        self.PP_LP = ui.PrePulse_LP.value()
        self.PP_Mode = ui.Waveform_PrePulse.currentIndex()
        self.CN_Mode = ui.Waveform_Conditioning.currentIndex()
        self.PP_GapFlag = ui.PrePulse_GapFlag.isChecked()
        self.PP_Notch_F1 = ui.PrePulse_Notch_F1.value()
        self.PP_Notch_F2 = ui.PrePulse_Notch_F2.value()
        self.PP_MultiFreq = str(ui.PrePulse_MultiFreq.text()) # free text field

        # Timing and Trials tab
        self.ITI_Var = ui.PrePulse_ITI_Var.value()
        self.ITI = ui.PrePulse_ITI.value()
        self.Trials = ui.PrePulse_Trials.value()
        self.NHabTrials = ui.PrePulse_NHabTrials.value()

        # Analysis tab
        self.readAnalysisTab()
        
        
    def readAnalysisTab(self):
        """Read the analysis window and filter settings from the Analysis tab.

        Called by readParameters, and usable on its own.  Note that
        Analysis_End is taken from the Analysis_Duration widget.
        """
        ui = self.ui
        self.Analysis_Start = ui.Analysis_Start.value()
        self.Analysis_End = ui.Analysis_Duration.value()
        self.Analysis_HPF = ui.Analysis_HPF.value()
        self.Analysis_LPF = ui.Analysis_LPF.value()
        

    def ToneTest(self):
        """Read the GUI and generate and plot a test tone at PP_Freq.

        The tone lasts PP_Dur msec.
        """
        self.readParameters()
        tone_freq = (self.PP_Freq, 0)
        self.StimulusMaker(mode='tone', freq=tone_freq, duration=self.PP_Dur,
                           plotSignal=True)
          
    def NoiseTest(self):
        """Read the GUI and generate and plot a test bandpass noise burst.

        The noise is filtered between PP_HP and PP_LP and lasts PP_Dur msec.
        """
        self.readParameters()
        band = (self.PP_HP, self.PP_LP)
        self.StimulusMaker(mode='bpnoise', freq=band, duration=self.PP_Dur,
                           plotSignal=True)
        

    
################################################################################
#
#   PrePulseRun controls the stimulus presentation and timing.
#   It is the main event loop during stimulation/acquisition.
#
# Note: we use a QTimer for the timing. One instance is created in the
# main init routine above. Its timeout is delivered through the Qt event loop
# (not a separate thread), so the GUI stays responsive between trials and the
# presentation can be stopped cleanly.
################################################################################

    def PrePulseStart(self):
        """Set up and launch a prepulse/gap startle run.

        Reads the GUI parameters and builds the per-trial lists:
          ITI_List - inter-trial intervals (seconds, as used by NextTrial),
          Dur_List - conditioning durations (msec),
          Gap_List - whether each trial gets a gap/prepulse.
        It optionally writes the data-file header, then starts TrialTimer so
        that NextTrial runs the first trial.  Does nothing if a run is already
        in progress.
        """
        if self.PPGo:
            print("already running")
            return
        self.Status("Starting Run")
        # Data file name comes from the start time (minute resolution).
        dt = strftime('%Y%m%d%H%M')
        self.fn = dt + "_Startle.txt"

        self.readParameters() # get the parameters for stimulation
        self.TrialCounter = 0
        nhab = int(self.NHabTrials)
        ntest = int(self.Trials)
        ntotal = nhab + ntest
        # Jitter ITI and conditioning duration uniformly by +/- Var/2.
        self.ITI_List = self.ITI + self.ITI_Var*(rand(ntotal) - 0.5)
        self.Dur_List = self.CN_Dur + self.CN_Var*(rand(ntotal) - 0.5)
        # Habituation trials never get a gap/prepulse.  Test trials are half
        # with and half without, in random order.  When Trials is odd the
        # extra trial is a no-prepulse trial, so the list always has exactly
        # ntest entries.  random.sample raises ValueError if asked for more
        # items than exist, which is what happened before.
        test_flags = (ntest // 2)*[False, True] + (ntest % 2)*[False]
        self.Gap_List = nhab*[False] + sample(test_flags, ntest)
        if self.AutoSave:
            # Header is written only now, once all per-trial values exist.
            self.writeDataFileHeader(self.fn)

        self.Gap_StartleMagnitude = zeros(ntest)
        self.Gap_Counter = 0
        self.noGap_StartleMagnitude = zeros(ntest)
        self.noGap_Counter = 0
        self.PPGo = True
        self.TrialTimer.setSingleShot(True)
        self.TrialTimer.start(10) # start right away

# catch the stop button press 
    def PrePulseStop(self):
        """Handle the Stop button: silence the output and end the run.

        The steps are:
          - set both attenuators to maximum (setAttens with no arguments),
          - call HwOff,
          - clear PPGo, which the wait loops in playSound poll,
          - cancel any pending TrialTimer shot, so that NextTrial does not
            fire again and overwrite the "stopped" message.
        """
        self.setAttens() # attenuators down
        self.HwOff() # turn hardware off
        self.PPGo = False # signal the acquisition wait loops that we are stopping
        self.TrialTimer.stop()
        self.statusBar().showMessage("Stimulus/Acquisition Events stopped")

# TrialTimer callback: runs the next trial (if any remain) and re-arms the timer.

    def NextTrial(self):
        """TrialTimer callback: run one trial and schedule the next one.

        While trials remain (habituation plus test trials) and the run has not
        been stopped, it does the following:
          - re-arms the timer with this trial's ITI before presenting the
            stimulus, so the stimulus time is included in the ITI,
          - runs the trial,
          - optionally plots it and appends the data.
        Otherwise it marks the run as finished.
        """
        self.TrialTimer.stop()
        ntotal = self.Trials + self.NHabTrials
        # The per-trial lists built in PrePulseStart have exactly ntotal
        # entries, so valid counters are 0 .. ntotal-1.  The old test
        # "TrialCounter <= Trials" indexed past the end when NHabTrials == 0
        # and otherwise ran the wrong number of trials.
        if self.TrialCounter < ntotal and self.PPGo:
            self.statusBar().showMessage("Rep: %d of %d" % (self.TrialCounter+1,
                                                            ntotal))
            DoneTime = self.ITI_List[self.TrialCounter] # seconds
            self.TrialTimer.start(int(1000*DoneTime))
            self.Stim_Dur = self.Dur_List[self.TrialCounter] # randomized duration
            self.runOnePP()
            if self.WavePlot:
                self.plotSignal(self.wave_outL, self.wave_outR,
                                1000.0/float(def_sampleRate))
            self.Status('sent signal')
            if self.AutoSave:
                self.AppendData(self.fn)
                self.Status('appended data')
            self.TrialCounter += 1
        else:
            self.PPGo = False
            self.statusBar().showMessage("Test Complete")

################################################################################
# runOnePP - "run one prepulse" trial.
# Generate one stimulus set based on the choice. Builds both channels.
# Presents the stimuli if the flag is set.
################################################################################

    def runOnePP(self):
        """Build (and optionally play) the stimuli for one trial.

        wave_outL holds the conditioning sound (CN_Mode), lasting
        Stim_Dur + PP_Dur + PS_Dur + ST_Dur msec.  On trials flagged in
        Gap_List, a gap and/or prepulse (PP_Mode) is placed in the PP_Dur
        window that starts at Stim_Dur.

        wave_outR holds the startle burst (bandpass noise, 1-32 kHz), which
        starts at Stim_Dur + PP_Dur + PS_Dur.

        Both channels are zero-padded to equal length and played if
        StimEnable is set.

        Mode codes, used for both CN_Mode and PP_Mode:
          0 silence, 1 tone, 2 bandpass noise, 3 notch noise,
          4 multi tones, 5 AM tones, 6 AM noise.
        Modes 4 and 5 are currently generated as plain tones and mode 6 as
        plain bandpass noise.  StimulusMaker does not implement notch noise
        yet; it yields silence.
        """
        def pad_to(wave, npts):
            # Zero-pad (or truncate) wave to exactly npts samples so channel
            # arithmetic never fails on a one-sample rounding mismatch.
            wave = numpy.asarray(wave)
            if len(wave) >= npts:
                return wave[:npts]
            return append(wave, numpy.zeros(npts - len(wave)))

        # Conditioning stimulus: runs for the whole trial.
        if self.CN_Mode in (1, 4, 5):
            cnmode = 'tone'
            cnfreq = (self.PP_Freq, 0)
        elif self.CN_Mode in (2, 6):
            cnmode = 'bpnoise'
            cnfreq = (self.PP_HP, self.PP_LP)
        elif self.CN_Mode == 3:
            cnmode = 'notchnoise' # Note: notch is embedded into a bandpass noise
            cnfreq = (self.PP_HP, self.PP_LP, self.PP_Notch_F1, self.PP_Notch_F2)
        else:
            # 0 (or an unknown code): silence; the frequency is ignored.
            cnmode = 'silence'
            cnfreq = (self.PP_Freq, 0)
        total_dur = self.Stim_Dur + self.PP_Dur + self.PS_Dur + self.ST_Dur
        self.wave_outL = self.StimulusMaker(mode=cnmode, duration=total_dur,
                                            freq=cnfreq, samplefreq=def_sampleRate,
                                            delay=0, level=self.CN_Level)

        # Tailor the conditioning stimulus according to Gap_List.
        w_pp = [] # default: no prepulse
        if self.Gap_List[self.TrialCounter]:
            if self.PP_Mode == 0 or self.PP_GapFlag: # insert a gap
                self.wave_outL = self.insertGap(self.wave_outL, delay=self.Stim_Dur,
                                                duration=self.PP_Dur,
                                                samplefreq=def_sampleRate)
            if self.PP_Mode in (1, 4, 5): # tone prepulse
                w_pp = self.StimulusMaker(mode='tone', duration=self.PP_Dur,
                                          freq=(self.PP_Freq, 0), delay=self.Stim_Dur,
                                          samplefreq=def_sampleRate, level=self.PP_Level)
            elif self.PP_Mode in (2, 6): # bandpass noise prepulse
                w_pp = self.StimulusMaker(mode='bpnoise', duration=self.PP_Dur,
                                          freq=(self.PP_HP, self.PP_LP), delay=self.Stim_Dur,
                                          samplefreq=def_sampleRate, level=self.PP_Level)
            elif self.PP_Mode == 3: # notched noise prepulse
                # Uses the PP_Notch_* attributes (self.Notch_F1/F2 never
                # existed) and the prepulse duration, like the other modes.
                w_pp = self.StimulusMaker(mode='notchnoise', duration=self.PP_Dur,
                                          freq=(self.PP_HP, self.PP_LP,
                                                self.PP_Notch_F1, self.PP_Notch_F2),
                                          samplefreq=def_sampleRate, delay=self.Stim_Dur,
                                          level=self.PP_Level)

        if len(w_pp) > 0:
            self.wave_outL = self.wave_outL + pad_to(w_pp, len(self.wave_outL))

        # Startle sound. Note that it overlaps the end of the conditioning sound.
        self.wave_outR = self.StimulusMaker(mode='bpnoise',
                                            delay=(self.Stim_Dur + self.PP_Dur + self.PS_Dur),
                                            duration=self.ST_Dur, samplefreq=def_sampleRate,
                                            freq=(1000.0, 32000.0), level=self.ST_Level,
                                            channel=1)
        npts = max(len(self.wave_outL), len(self.wave_outR))
        self.wave_outL = pad_to(self.wave_outL, npts)
        self.wave_outR = pad_to(self.wave_outR, npts)

        if self.StimEnable:
            self.playSound(self.wave_outL, self.wave_outR, def_sampleRate)
 
        
################################################################################
# STIMULUS GENERATION ROUTINES
#
# transcribed from Matlab. P. Manis, Nov. 28-December 1 2008.
################################################################################
    def StimulusMaker(self, mode = 'tone', amp = 1, freq = (1000, 3000, 4000), delay = 0, duration = 2000,
                  rf = 2.5, phase0 = 0, samplefreq = 44100, ipi = 20, np = 1, alternate = 1, level = 70,
                  playSignal = False, plotSignal= False, channel = 0):
        """Generate a shaped tone or noise waveform, optionally as a pulse train.

        mode: one of
            'tone'     - sum of sines at each frequency in freq,
            'bbnoise'  - Gaussian noise,
            'bpnoise'  - Gaussian noise band-passed between freq[0] and
                         freq[1] Hz.  If the upper edge cannot be realized
                         below Nyquist, a high-pass at freq[0] is used instead.
            'notchnoise', 'multitones', 'silence' - not implemented yet;
                         return zeros of the stimulus duration, with no
                         delay and no level scaling.
        amp: amplitude multiplier applied before level scaling.
        freq: frequencies in Hz; their meaning depends on mode.
        delay, duration, rf, ipi: times in msec.  The rf rise/fall ramp is
            included in duration.
        phase0: accepted but ignored; tones always start at 0 phase.
        samplefreq: sample rate in Hz.
        np: number of pulses, spaced ipi msec apart.  If alternate == True,
            every second pulse is inverted.
        level: dB SPL, converted to a scale factor by dbconvert(level, channel).
        playSignal, plotSignal: if True, play / plot the result immediately.

        Returns a numpy array of samples.
        """
        clock = 1000.0/samplefreq # sample interval in msec
        Fs = 1000/clock # sample rate in Hz (== samplefreq)
        fil = self.rfShape(0, duration, clock, rf) # 0..1 envelope with 0 delay
        jd = max(int(floor(delay/clock)), 0) # first sample of the stimulus
        siglen = len(fil)
        jpts = arange(0, siglen)
        signal = numpy.zeros(siglen)

        if mode == 'tone':
            for f in freq:
                signal = signal + fil*amp*sin(2*pi*f*jpts/Fs)
                self.Status("Generated Tone at %7.1fHz" % (f))

        if mode == 'bbnoise':
            signal = signal + fil*amp*normal(0, 1, siglen)
            self.Status("BroadBand Noise " )

        if mode == 'bpnoise':
            tsignal = fil*amp*normal(0, 1, siglen)
            # Pass/stop band edges normalized to Nyquist.
            wp = [float(freq[0])/samplefreq*2, float(freq[1])/samplefreq*2]
            ws = [0.75*float(freq[0])/samplefreq*2, 1.25*float(freq[1])/samplefreq*2]
            if ws[1] >= 1.0:
                # The upper stop edge is at or above Nyquist (for example,
                # the 32 kHz startle band at a 44.1 kHz sound-card rate).
                # iirdesign cannot realize this, so design a high-pass
                # instead: the band then runs from freq[0] up to Nyquist.
                wp = wp[0]
                ws = ws[0]
            filter_b, filter_a = scipy.signal.iirdesign(wp, ws, gpass=1.0,
                                                        gstop=60.0, ftype="ellip")
            self.Status("BandPass Noise %7.1f-%7.1f" % (freq[0], freq[1]))
            signal = scipy.signal.lfilter(filter_b, filter_a, tsignal)

        if mode in ('notchnoise', 'multitones', 'silence'):
            # Not implemented yet: all-zero buffer of the stimulus duration.
            return array(signal)

        # Place np copies of the signal, ipi msec apart, after jd samples of
        # delay.  The size must be an int (numpy rejects float sizes).
        w = numpy.zeros(int(ceil(ipi*(np-1)/clock)) + jd + siglen)
        sign = numpy.ones(np)
        if alternate == True:
            sign[1::2] = -1 # invert every other pulse
        step = int(floor(ipi/clock))
        for i in range(np):
            j0 = jd + i*step
            w[j0:j0+siglen] = sign[i]*signal

        # Scale the unit-amplitude waveform to the requested level (dB SPL).
        w = w*self.dbconvert(spl = level, chan = channel)
        if playSignal == True:
            self.playSound(w, w, samplefreq)

        if plotSignal == True:
            self.plotSignal(w, w, clock)
        return array(w)


#
# Rise-fall shaping of a waveform. This routine generates an envelope with
# 1 as the signal max, and 0 as the baseline (off), with cosine^2 shaping of
# duration rf starting at delay (msec). The duration of the signal includes the
# rise and fall, so the duration of the signal at full amplitude is dur - 2*rf.
#
    def rfShape(self, delay=0, duration=100, clock=44100, rf=2.5):
        """Return a 0..1 envelope with sin^2 rise and fall ramps.

        delay, duration and rf are in msec.  clock is the sample interval in
        msec; callers pass 1000/samplefreq.

        The envelope has floor((delay+duration)/clock) samples:
          - it is 0 before `delay`,
          - it rises over rf msec,
          - it stays at 1,
          - it falls symmetrically over rf msec, reaching 0 at
            delay + duration.
        The ramps are part of `duration`, so full amplitude lasts
        duration - 2*rf msec.  If duration < 2*rf, the ramps are shortened to
        duration/2 so they still fit.
        """
        jd = max(int(floor(delay/clock)), 0) # first sample of the rise
        je = int(floor((delay + duration)/clock)) # one past the last sample
        fil = numpy.zeros(max(je, 0))
        if je <= jd:
            return fil
        rf = min(rf, duration/2.0)
        # Samples per ramp; never more than half the envelope so the rise and
        # fall cannot overlap or run off the end of the array.
        nf = min(int(floor(rf/clock)), (je - jd)//2)
        fil[jd + nf:je - nf] = 1
        if nf > 0:
            fo = 1.0/(4.0*rf) # kHz: a quarter sine cycle lasts rf msec
            # Ramp time is measured from the onset (jd), not from sample 0.
            # The old absolute-index ramp did not start at 0 or reach 1 when
            # delay > 0, which produced clicks at the edges of inserted gaps.
            rise = sin(2*pi*fo*arange(nf)*clock)**2
            fil[jd:jd + nf] = rise
            fil[je - nf:je] = rise[::-1] # mirror-image fall ending at 0
        return fil

#
# insertGap takes a waveform and inserts a shaped gap into it.
# currently, gap is all the way off, i.e., 0 intensity.
# a future change is to include relative gap level (-dB from current waveform)
#

    def insertGap(self, wave, delay = 20, duration = 20, rf = 2.5, samplefreq = def_sampleRate):
        """Return wave with a shaped, fully silent gap inserted.

        The gap starts at `delay` msec and lasts `duration` msec, including
        its rf-msec ramps (envelope from rfShape).  The result always has the
        same length as `wave`; any part of the gap beyond the end of `wave`
        is ignored.
        """
        clock = 1000.0/samplefreq # sample interval in msec
        fil = self.rfShape(delay, duration, clock, rf) # gap envelope starting at delay
        lenw = len(wave)
        if len(fil) < lenw:
            fil = append(fil, numpy.zeros(lenw - len(fil)))
        else:
            # The previous code padded fil further in this case, which
            # guaranteed a shape mismatch in the product below.
            fil = fil[:lenw]
        return wave*(1.0 - fil)
        
        
#
# compute voltage from reference dB level
# db = 20 * log10 (Vsignal/Vref)
#
    def dbconvert(self, spl = 0, chan = 0):
        """Return the voltage scale factor that produces `spl` dB SPL.

        Uses dB = 20*log10(V/Vref).  The reference level is REF_MAG_dB for
        chan == 1 and REF_ES_dB otherwise; both are produced by REF_ES_volt.
        Multiply a waveform normalized to 1 by the result.
        """
        if chan == 1:
            ref = REF_MAG_dB
        else:
            ref = REF_ES_dB
        volts_at_0dB = REF_ES_volt/(10**(ref/20.0))
        sf = volts_at_0dB*10**(spl/20.0)
        print("scale = %f for %f dB" % (sf, spl))
        return sf

################################################################################
# hardware interactions:
#
# set the attenuators on the PA5.
# If no args are given, set to max attenuation

    def setAttens(self, attenl = 120, attenr = 120):
        """Set PA5 attenuator 1 to attenl and attenuator 2 to attenr (dB).

        With no arguments, both are set to maximum attenuation.  Only acts
        when running on the NI/TDT hardware.
        """
        if my_hardware != 'nidaq':
            return
        for device, atten in ((1, attenl), (2, attenr)):
            PA5.ConnectPA5("USB", device)
            PA5.SetAtten(atten)

#
# playSound sends the sound out to an audio device. In the absence of NI card
# and TDT system, it will use the system audio device (sound card, etc)
# The waveform is played in stereo.
#
    def playSound(self, wavel, waver, samplefreq = 44100):
        if my_hardware == 'pyaudio':
            self.audio = pyaudio.PyAudio()
            chunk = 1024
            FORMAT = pyaudio.paFloat32
            CHANNELS = 2
            RATE = samplefreq
            self.RPSF = samplefreq
            self.SF = self.RPSF
            self.stream = self.audio.open(format = FORMAT,
                            channels = CHANNELS,
                            rate = int(RATE),
                            output = True,
                            input = True,
                            frames_per_buffer = chunk)
            # play stream
            wave = zeros(2*len(wavel))
            if len(wavel) != len(waver):
                print "waves not matched in length: %d vs. %d (L,R)" % (len(wavel), len(waver))
                return
            (waver, clipr) = self.clip(waver, 20.0)
            (wavel, clipl) = self.clip(wavel, 20.0)
            wave[0::2] = waver 
            wave[1::2] = wavel  # order chosen so matches entymotic earphones on my macbookpro.
            postdur =  int(float(self.PostDuration*self.SF))
            rwave = self.read_array(len(wavel)+postdur, CHANNELS)
            self.write_array(wave)
            self.stream.stop_stream()
            self.stream.close()
            self.audio.terminate()
            self.ch1 = rwave[0::2]
            self.ch2 = rwave[1::2]
            
        if my_hardware == 'nidaq':
            self.task = dev0.createTask()  # creat a task for the NI 6731 board.
            self.task.CreateAOVoltageChan("/Dev2/ao0", "ao0", -10., 10.,
                                          nidaq.Val_Volts, None)
            self.task.CreateAOVoltageChan("/Dev2/ao1", "ao1", -10., 10.,
                                          nidaq.Val_Volts, None) # use 2 channels
            wlen = 2*len(wavel)
            self.task.CfgSampClkTiming(None, samplefreq, nidaq.Val_Rising,
                                       nidaq.Val_FiniteSamps, len(wavel))
            # DAQmxCfgDigEdgeStartTrig (taskHandle, "PFI0", DAQmx_Val_Rising);
            self.task.SetStartTrigType(nidaq.Val_DigEdge)
            self.task.CfgDigEdgeStartTrig('PFI0',  nidaq.Val_Rising)
            daqwave = numpy.zeros(wlen)
            (wavel, clipl) = self.clip(wavel, 10.0)
            (waver, clipr) = self.clip(waver, 10.0)
            
            daqwave[0:len(wavel)] = wavel
            daqwave[len(wavel):] = waver # concatenate channels (using "groupbychannel" in writeanalogf64)
            dur = wlen/float(samplefreq)
            self.task.write(daqwave)
            # now take in some acquisition...
            a = RP21.ClearCOF()
            if a <= 0:
                print "problem with RP21"
                return
            samp_cof_flag = 2 # 2 is for 24.4 kHz
            samp_flist = [6103.5256125, 12210.703125, 24414.0625, 48828.125, 
            97656.25, 195312.5]
            if samp_cof_flag > 5:
                samp_cof_flag = 5
            a = RP21.LoadCOFsf("C:\pyStartle\startle2.rco", samp_cof_flag)
            if a > 0:
                print "Connected to TDT RP2.1 and startle2.rco is loaded"
            else:
                print "Error loading startle2.rco?, error = %d" % (a)
                hwerr = 1
                return
            self.RPSF = RP21.GetSFreq()
            self.SF = self.RPSF
            Ndata = ceil(0.5*(dur+self.PostDuration)*self.RPSF)
            RP21.SetTagVal('REC_Size', Ndata)  # old version using serbuf  -- with
            # new version using SerialBuf, can't set data size - it is fixed.
            # however, old version could not read the data size tag value, so
            # could not determine when buffer was full/acquisition was done.
            self.setAttens(10.0,10.0) # set equal, but not at minimum...

            self.task.start() # start the NI AO task
            a=RP21.Run() # start the RP2.1 processor...
            a=RP21.SoftTrg(1) # and trigger it. RP2.1 will in turn start the ni card
            while not self.task.isTaskDone():  # wait for AO to finish?
                if not self.PPGo: # while waiting, check for stop.
                    RP21.Halt()
                    self.task.stop()
                    return
            self.task.stop() # done, so stop the output.
            self.setAttens() # attenuators down (there is noise otherwise)
            # read the data...
            curindex1=RP21.GetTagVal('Index1')
            curindex2=RP21.GetTagVal('Index2')
            while(curindex1 < Ndata or curindex2 < Ndata): # wait for input data to be sampled
                if not self.PPGo: # while waiting, check for stop.
                    RP21.Halt()
                    return
                curindex1=RP21.GetTagVal('Index1')
                curindex2=RP21.GetTagVal('Index2')
            self.task.stop()   
            self.ch2=RP21.ReadTagV('Data_out2', 0, Ndata)
            # ch2 = ch2 - mean(ch2[1:int(Ndata/20)]) # baseline: first 5% of trace
            self.ch1=RP21.ReadTagV('Data_out1', 0, Ndata)
            RP21.Halt()
    
    def HwOff(self):
        """Turn the active audio hardware off, if possible.

        The backend is selected by the module-level ``my_hardware`` setting:

        * ``'pyaudio'`` -- stop and close ``self.stream`` and terminate the
          PyAudio object held in ``self.audio``.
        * ``'nidaq'`` -- stop the NI output task, call ``setAttens()`` with
          no arguments, and halt the TDT RP2.1 (``RP21``).
        """
        backend = my_hardware
        if backend == 'pyaudio':
            stream = self.stream
            stream.stop_stream()
            stream.close()
            self.audio.terminate()
        elif backend == 'nidaq':
            self.task.stop()
            self.setAttens()
            RP21.Halt()
            
# clip data to max value (+/-) to avoid problems with daqs
    def clip(self, data, maxval):
        """Limit ``data`` to [-maxval, +maxval] so the DAC is never overdriven.

        Samples at or beyond either limit are replaced by that limit. A numpy
        array argument is modified in place and is also returned.

        Parameters
        ----------
        data : array_like
            Waveform samples.
        maxval : float
            Positive clipping level (the NI output is driven with 10.0 here).

        Returns
        -------
        (data, clip) : (numpy.ndarray, int)
            The clipped samples, and 1 if any sample reached or exceeded
            either limit, otherwise 0.
        """
        data = numpy.asarray(data)  # same object for ndarray input -> in place
        over = data >= maxval
        under = data <= -maxval
        data[over] = maxval
        # the negative side uses its own mask (it previously reused the
        # positive indices, so negative excursions were never clipped)
        data[under] = -maxval
        clip = int(over.any() or under.any())
        return (data, clip)
        
        
#            print 'Data acquired ok, %d points' % (curindex1)
################################################################################
# the following was taken from http://hlzr.net/docs/pyaudio.html
# it is used for reading and writing to the system audio device
#
    def write_array(self, data):
        """Send a sequence of samples to the PyAudio output stream.

        Every sample is packed as a native C float ('@f') into a ctypes string
        buffer, and the buffer is written to ``self.stream`` in one call.
        """
        count = len(data)
        packer = struct.Struct('@' + 'f' * count)  # one 'f' per sample
        buf = ctypes.create_string_buffer(packer.size)
        packer.pack_into(buf, 0, *data)
        self.stream.write(buf)
    
    def read_array(self, size, channels=1):
        """Read ``size`` frames from the PyAudio stream as a numpy array.

        The raw bytes from ``self.stream.read(size)`` are decoded as
        ``size * channels`` native C floats, in the order they arrive.
        """
        raw = self.stream.read(size)
        nvalues = size * channels
        samples = struct.unpack('@%df' % nvalues, raw)
        return numpy.array(samples)
    
################################################################################
#
# plot the signal and its power spectrum
#     
    def plotSignal(self, wL, wR, clock):
        """Redraw figure 1 with the stimulus, the recorded response and d'.

        Panels:

        * top left: left (blue) and right (red) output waveforms against
          ``clock * index / 1000``;
        * top right (only if ``self.ShowSpectrum``): power spectrum of ``wL``;
        * middle: ``self.ch1`` (green) and ``self.ch2`` (red) against
          ``self.response_tb`` (``1/self.SF`` per point, stored on self);
        * bottom: spectrum / analysis-window / magnitude panels filled in by
          ``Response_Analysis``.

        Traces are decimated by ``int(npts / self.maxptsplot)`` (at least 1).
        When ``self.TrialCounter > 0`` the current d' is written on the figure
        as "Rd" and sent to ``self.ui.Rd_Dial`` (scaled by 100).

        Parameters
        ----------
        wL, wR : numpy arrays
            Left and right output waveforms.
        clock : float
            Sample interval of ``wL``/``wR``; also handed to ``pSpectrum``.
        """
        npts = len(wL)
        t = clock * numpy.arange(0, npts) / 1000
        self.datafig = plt.figure(1)
        skip = int(npts / self.maxptsplot)
        if skip < 1:
            skip = 1
        plt.clf()

        plt.axes([0.07, 0.7, 0.55, 0.27])  # top subplot has stimulus waveforms.
        plt.plot(t[0::skip], wL[0::skip], 'b-')
        plt.hold(True)
        plt.plot(t[0::skip], wR[0::skip], 'r-')

        # spectrum of signal
        if self.ShowSpectrum:
            plt.axes([0.7, 0.7, 0.27, 0.27])  # top right: power spectrum of the left signal
            (spectrum, freqAzero) = self.pSpectrum(wL, clock)
            plt.plot(freqAzero, 1000 * spectrum)

        # response (ch1) and the second input channel (ch2)
        plt.axes([0.07, 0.37, 0.55, 0.12])
        self.response_tb = float(1.0 / self.SF) * numpy.arange(0, len(self.ch1))
        plt.plot(self.response_tb[0::skip], self.ch1[0::skip], 'g-')
        plt.axes([0.07, 0.49, 0.55, 0.12])
        plt.plot(self.response_tb[0::skip], self.ch2[0::skip], 'r-')

        # spectrum of the response and the analysis panels
        self.SpecAxes = plt.axes([0.7, 0.37, 0.27, 0.27])
        self.SignalAxes = plt.axes([0.07, 0.05, 0.4, 0.27])
        self.ResponseAxes = plt.axes([0.57, 0.05, 0.4, 0.27])
        tdelay = self.Stim_Dur + self.PP_Dur + self.PS_Dur
        # Response_Analysis takes the sample interval in msec/point (it uses
        # rate/1000 as the spectrum clock), not the sample rate itself.
        # NOTE(review): response_tb is in seconds while tdelay and the
        # Analysis_* window are added as-is in Response_Analysis -- confirm units.
        self.Response_Analysis(timebase=self.response_tb, signal=self.ch1,
                               rate=1000.0 / self.RPSF, delay=tdelay,
                               SpecAxes=self.SpecAxes,
                               SignalAxes=self.SignalAxes,
                               ResponseAxes=self.ResponseAxes,
                               ntrials=self.Trials,
                               trialcounter=self.TrialCounter,
                               gaplist=self.Gap_List)
        if self.TrialCounter > 0:
            # d' from the running statistics Startle_Analyze keeps on self
            # (same formula as there); 0.0 until both spreads are non-zero.
            dprime = 0.0
            gap_std = getattr(self, 'Gap_std', 0.0)
            nogap_std = getattr(self, 'noGap_std', 0.0)
            if gap_std != 0 and nogap_std != 0:
                dprime = ((self.noGap_mean - self.Gap_mean) /
                          numpy.sqrt(nogap_std ** 2 + gap_std ** 2))
            plt.figtext(0.82, 0.20, "Rd: %7.3f" % (dprime))
            self.ui.Rd_Dial.setValue(int(100 * dprime))
        plt.show()
        plt.draw()
        

    def getSelectionIndices(self, x, xstart, xend):
        """Return the indices of the 1-D array ``x`` with xstart <= x <= xend.

        The indices are returned as a list in ascending order. Callers rely
        on this: ``Response_Analysis`` takes ``apts[0]`` as the start of the
        window. The previous set-intersection version returned them in hash
        order, which is not sorted once the indices wrap the hash table.
        """
        xa = numpy.asarray(x)
        inside = numpy.logical_and(xa >= xstart, xa <= xend)
        return list(numpy.nonzero(inside)[0])
    
# compute the power spectrum.
# simple, no windowing etc...

    def pSpectrum(self, data, clock):
        """Return a one-sided power spectrum of ``data`` (no windowing).

        ``data`` is zero-padded to twice its length and transformed with an
        FFT. The magnitude of each non-negative-frequency bin is divided by
        the padded length and squared. A floor of 1e-12 times the peak is
        added (presumably to keep later displays away from zero). Every bin
        except DC (and, for an even-length transform, the Nyquist bin) is
        doubled to account for the discarded negative frequencies.

        Parameters
        ----------
        data : sequence of float
            Time-domain samples.
        clock : float
            Sample interval. Bin k is at ``k / (clock * npts)``, so the
            frequency axis is in the reciprocal of ``clock``'s units (Hz for
            seconds).

        Returns
        -------
        (spectrum, freqAzero) : tuple of numpy arrays
            Power per bin and the matching frequencies. Both are empty for
            empty ``data``.
        """
        npts = len(data)
        if npts == 0:
            return (numpy.zeros(0), numpy.zeros(0))
        # TODO: window the data before the FFT
        padw = numpy.append(data, numpy.zeros(npts))  # zero-pad to twice the length
        npts = len(padw)
        spfft = fft(padw)
        # number of non-negative-frequency bins; must be an int to slice
        nUniquePts = int(numpy.ceil((npts + 1) / 2.0))
        # scale by the number of points so that the magnitude does not depend
        # on the length of the signal or on its sampling frequency
        spectrum = abs(spfft[0:nUniquePts]) / float(npts)
        spectrum = spectrum ** 2  # square it to get the power
        spectrum = spectrum + 1e-12 * numpy.amax(spectrum)
        # multiply by two; an odd-length FFT has no Nyquist bin to exclude
        if npts % 2 > 0:
            spectrum[1:] = spectrum[1:] * 2
        else:
            spectrum[1:-1] = spectrum[1:-1] * 2
        # 1.0/clock: avoid integer division for an integer clock (Py2)
        freqAzero = numpy.arange(0, nUniquePts, 1.0) * ((1.0 / clock) / npts)
        return (spectrum, freqAzero)

# filter signal with elliptical filter
    def SignalFilter(self, signal, LPF, HPF, samplefreq):
        """Band-pass ``signal`` between ``HPF`` and ``LPF`` with an elliptic IIR.

        The filter is designed with scipy.signal.iirdesign:

        * passband [HPF, LPF], with at most 1 dB ripple;
        * stopband edges at HPF/2 and 2*LPF, with at least 60 dB attenuation.

        All frequencies are normalized to ``samplefreq/2``, so they must be in
        the same units as ``samplefreq``. The design values are printed, and
        the filter is applied once, forward only, with scipy.signal.lfilter.

        NOTE(review): if 2*LPF reaches samplefreq/2 the stopband edge is at
        or beyond Nyquist and scipy is likely to reject the design.
        """
        rate = float(samplefreq)
        nyq = rate / 2
        low_cut = float(HPF)
        high_cut = float(LPF)
        passband = [low_cut / nyq, high_cut / nyq]
        stopband = [0.5 * low_cut / nyq, 2 * high_cut / nyq]
        print("signalfilter: samplef: %f  wp: %f, %f  ws: %f, %f lpf: %f  hpf: %f" % (
            rate, passband[0], passband[1], stopband[0], stopband[1], high_cut, low_cut))
        numer, denom = scipy.signal.iirdesign(passband, stopband,
                                              gpass=1.0, gstop=60.0, ftype="ellip")
        return scipy.signal.lfilter(numer, denom, signal)
    
    def Write_Data(self):
        """Start the file ``test.dat`` with the current parameter header."""
        header_file = 'test.dat'
        self.writeDataFileHeader(header_file)
        
    def writeDataFileHeader(self, filename):
        """Create (or truncate) ``filename`` and write the parameter header.

        Two lines are written. Each is the ``repr`` of a dict followed by a
        space and a newline:

        1. the stimulus, timing/trial and analysis parameters, keyed by the
           attribute name they are read from;
        2. ``{'GapList': self.Gap_List}`` -- the trial sequencing, kept on a
           line of its own.

        The file is closed even if a write fails.
        """
        # Keys are identical to the attribute names; the order matches the
        # original insertion order so the written dict is unchanged.
        param_names = (
            # conditioning / prepulse / startle / display settings
            'CN_Level', 'CN_Dur', 'CN_Var', 'PP_Level', 'PP_OffLevel',
            'PP_Dur', 'PS_Dur', 'ST_Dur', 'ST_Level', 'StimEnable', 'WavePlot',
            # from the Waveforms tab
            'PP_Freq', 'PP_HP', 'PP_LP', 'PP_Mode', 'CN_Mode',
            'PP_Notch_F1', 'PP_Notch_F2', 'PP_MultiFreq', 'PP_GapFlag',
            # from the Timing and Trials tab
            'ITI_Var', 'ITI', 'Trials', 'NHabTrials',
            # analysis parameters
            'Analysis_Start', 'Analysis_End', 'Analysis_HPF', 'Analysis_LPF',
        )
        filedict = {}
        for name in param_names:
            filedict[name] = getattr(self, name)
        filedict_gap = {}
        filedict_gap['GapList'] = self.Gap_List  # save the sequencing information

        print("Writing File: %s" % (filename))
        hdat = open(filename, 'w')
        try:
            hdat.write("%s \n" % (filedict))
            hdat.write("%s \n" % (filedict_gap))  # write in separate lines
        finally:
            hdat.close()
        
    def AppendData(self, filename):
        """Append the current trial as one record at the end of ``filename``.

        A record has two parts:

        * one line holding the ``repr`` of a dict with 'Points'
          (``len(self.response_tb)``), 'SampleRate' (``self.SF``), and
          'GapMode', 'ITI' and 'CNDur'. The last three are taken from
          ``self.Gap_List``, ``self.ITI_List`` and ``self.Dur_List`` at index
          ``self.TrialCounter``.
        * one "time ch1 ch2" line per sample, with both channels multiplied
          by 1000 and every value formatted with "%f".

        The whole record is built before the file is opened. A missing list
        entry, or a channel shorter than ``self.response_tb``, therefore
        raises without leaving a partial record behind, and the file is
        always closed.
        """
        trial = self.TrialCounter
        datainfo = {}
        datainfo['Points'] = len(self.response_tb)
        datainfo['SampleRate'] = self.SF
        datainfo['GapMode'] = self.Gap_List[trial]
        datainfo['ITI'] = self.ITI_List[trial]
        datainfo['CNDur'] = self.Dur_List[trial]
        lines = ["%s \n" % (datainfo)]
        for i, tval in enumerate(self.response_tb):
            lines.append("%f %f %f\n" % (tval, 1000 * self.ch1[i],
                                         1000 * self.ch2[i]))
        hdat = open(filename, 'a')
        try:
            hdat.writelines(lines)
        finally:
            hdat.close()

 
# do an eval on a long line (longer than 512 characters)
# assumes input is a dictionary that is too long
# parses by breaking the string down and then reconstructing each element
#
    def long_eval(self, line):
        """Rebuild a dict from its one-line ``repr``, one entry at a time.

        ``line`` is split on ',' and each piece is evaluated on its own as a
        one-entry dict literal. Pieces that do not parse that way are
        skipped, so their entries are missing from the result. This includes
        the fragments of list or tuple values, which themselves contain
        commas.

        NOTE(review): the pieces are passed to ``eval``; only use this on
        data files written by this program.
        """
        u = {}
        for di in line.split(','):
            # Strip whitespace (e.g. the trailing " \n" the writers append)
            # before the braces; otherwise the closing '}' of the last entry
            # survives and that entry never parses.
            body = di.strip().strip('{}')
            try:
                r = eval('{' + body + '}')
                items = r.items()
            except Exception:  # not a self-contained "key: value" fragment
                continue
            for key, value in items:
                u[key] = value
        return(u)

    def Analysis_Test(self):
        """Visual check of ``SignalFilter`` on synthetic white noise.

        Generates 10000 samples of zero-mean, unit-variance Gaussian noise,
        treated as sampled at 24410 Hz. The noise is filtered with the
        current Analysis tab band edges (``Analysis_HPF``/``Analysis_LPF``).
        Figure 3 then shows the power spectra of the raw (black) and
        filtered (red) noise.
        """
        self.readParameters()
        self.readAnalysisTab()
        npts = 10000
        samplefreq = 24410.0
        interval = 1.0 / samplefreq  # seconds per point -> spectrum axis in Hz
        signal = numpy.random.normal(0, 1, npts)
        (Rspectrum, Rfreqs) = self.pSpectrum(signal, interval)
        fa = self.SignalFilter(signal, self.Analysis_LPF, self.Analysis_HPF, samplefreq)
        (fRspectrum, fRfreqs) = self.pSpectrum(fa, interval)
        plt.figure(3)
        plt.clf()
        plt.plot(Rfreqs, Rspectrum, 'k-')
        plt.hold(True)
        plt.plot(fRfreqs, fRspectrum, 'r-')
        plt.show()
        
################################################################################
#   Analysis routines
#
################################################################################
    def Analysis_Read(self, filename = None):
        """Load a startle data file and pass its trials to ``Analyze_Data``.

        If ``filename`` is None, the user picks a file in a dialog (filter
        "*.txt"). The parser expects the layout that ``writeDataFileHeader``
        followed by ``AppendData`` calls produces:

        * line 1: the parameter dict -> ``self.paramdict``
        * line 2: the gap-list dict -> ``self.paramdict_gap``
        * per trial: a record-header dict -> ``self.headerdict``, followed by
          'Points' lines of "time ch1 ch2". The header must contain 'Points'
          and 'SampleRate'; 'GapMode' and 'CNDur' are optional. As written by
          AppendData, the channel columns are the recorded values x 1000.

        Each complete trial is appended to ``self.a_t``, ``self.a_ch1`` and
        ``self.a_ch2``. ``self.gapmode`` collects the per-trial gap flag
        (False if absent). ``self.delaylist`` collects CNDur + PP_Dur +
        PS_Dur, but only for records that carry 'CNDur'. A trailing partial
        trial is discarded.

        If the file cannot be opened (IOError), a status message is shown and
        nothing is loaded or analyzed.
        """
        self.readParameters() # to be sure we have "showspectrum"
        self.readAnalysisTab()
        if filename == None:
            fd = QtGui.QFileDialog(self)
            self.inFileName = str(fd.getOpenFileName(self, "Get input file", "",
                                                     "data files (*.txt)"))
        else:
            self.inFileName = filename
        try:
            hstat = open(self.inFileName,"r")
            (p, f)  = os.path.split(self.inFileName)
            self.setMainWindow(f)
        except IOError:
            self.Status( "%s not found" % (self.inFileName))
            return
        lineno = 0
        state = 0 # initial state
                    # states, as used below:
                    # 0 - reading the two file-header lines
                    # 1 - expecting a record "header" line
                    # 2 - reading the data lines of the current record
                    # (there is no state 3)
        # data line "time ch1 ch2": groups 2, 3 and 4 are the three numbers
        parse =  re.compile("(^([\-0-9.]*) ([\-0-9.]*) ([\-0-9.]*))")
        # trial number shown in the status bar; it is incremented before
        # display, so the first trial is reported as "Trial 2"
        reccount = 1
        self.a_t = [] 
        self.a_ch1 = [] 
        self.a_ch2 = [] 
        self.gapmode = []
        self.delaylist = []
        header_linecount = 0 # which of the two file-header lines comes next
        lineno = 0 # counted below but not otherwise used
        for line in hstat:
            lineno = lineno + 1
            if state == 0:
                if header_linecount == 0:
                    self.statusBar().showMessage("Reading Header" )   
                    self.paramdict = self.long_eval(line)
                    header_linecount = 1
                    continue
                if header_linecount == 1:
                    self.paramdict_gap = self.long_eval(line)
                    state = 1
                    header_linecount = 0
                    print ".",
                continue
            if state == 1:
                # record header: number of points, sample rate, trial conditions
                self.headerdict =  self.long_eval(line)
                self.npts = self.headerdict['Points']
                # NOTE(review): despite its name this is 1/SampleRate (the
                # sample interval if SampleRate is in Hz); not used here
                self.samplefreq = 1.0/float(self.headerdict['SampleRate'])
                if self.headerdict.has_key('GapMode'): # build gap mode array
                    self.gapmode.append(self.headerdict['GapMode'])
                else:
                    self.gapmode.append(False)
                if self.headerdict.has_key('CNDur'): # build duration  (delay to startle) array
                    # delay to startle = conditioning duration + PP_Dur + PS_Dur
                    self.delaylist.append(self.headerdict['CNDur'] +
                                          self.paramdict['PP_Dur'] + self.paramdict['PS_Dur'])                    
                reccount += 1
                self.statusBar().showMessage("Reading Trial %d" % (reccount) )   
                state = 2
                i = 0 # next sample index within this record
                t = numpy.zeros(self.npts)
                ch1 = numpy.zeros(self.npts)
                ch2 = numpy.zeros(self.npts)
                print ".",
                continue
            if state == 2:
                    # assumes every data line matches the pattern; a
                    # malformed line raises here (mo is None, or float(''))
                    mo = parse.search(line)
                    t[i] = float(mo.group(2))
                    ch1[i] = float(mo.group(3))
                    ch2[i] = float(mo.group(4))
                    i = i + 1
                    if i >= self.npts:
                        # record complete: store copies of this trial's arrays
                        self.a_t.append(numpy.array(t))
                        # (no filtering here; ch1 is filtered in Analyze_Data)
                        self.a_ch1.append(numpy.array(ch1))
                        self.a_ch2.append(numpy.array(ch2))
                        state = 1 # reset the state to read the next points list
        hstat.close()
        self.statusBar().showMessage("Done Reading")   
        self.Analyze_Data()
        
    def Analyze_Data(self):
        """Filter, plot and average the trials loaded by ``Analysis_Read``.

        Reads:

        * ``self.a_t``/``self.a_ch1``: per-trial time base and response;
        * ``self.gapmode``: per-trial gap flags;
        * ``self.delaylist``: per-trial delay to the startle, msec;
        * ``self.paramdict``: the file header;
        * ``self.headerdict``: its 'SampleRate' is used for every trial.

        Produces:

        * ``self.fa_ch1``: the band-pass filtered ch1 of every trial (float32).
        * Figure 2: one panel per trial with the filtered response over
          ``self.Analysis_End`` msec from the startle. Habituation trials are
          green, gap/prepulse trials red, no-gap trials blue.
        * Figure 3: ``Response_Analysis`` panels for the trials after
          habituation, plus the average gap (red) and no-gap (blue) response.
          Traces whose window runs past the end of the record are left out of
          the averages, and a condition with no trials is not plotted.

        Returns immediately if no trials were loaded.
        """
        self.readParameters()  # to be sure we have "showspectrum"
        self.readAnalysisTab()
        ds = len(self.a_t)
        if ds == 0:
            return  # nothing loaded; a 0 x 0 subplot grid would fail
        samplerate = float(self.headerdict['SampleRate'])
        srate = 1000.0 / samplerate  # msec per point
        # only the post-startle section of each trace is analyzed
        stdur = int(self.Analysis_End / srate)  # points after startle
        nhab = int(self.paramdict['NHabTrials'])

        plt.figure(2)
        plt.clf()
        rows = int(numpy.sqrt(ds))
        cols = int(ds / rows)
        if rows * cols < ds:
            cols += 1

        # filter only the response channel, not the microphone channel (ch2)
        self.fa_ch1 = []
        for trace in self.a_ch1:
            fa = self.SignalFilter(trace, self.Analysis_LPF, self.Analysis_HPF, samplerate)
            self.fa_ch1.append(fa.astype('float32'))

        for k in range(ds):
            if k < nhab:
                pline = 'g-'  # habituation trials are always green
            elif self.gapmode[k]:
                pline = 'r-'  # with prepulse, red
            else:
                pline = 'b-'
            if self.gapmode[k]:
                label = "PP"
            else:
                label = "NoPP"
            plt.subplot(rows, cols, k + 1)
            ststart = int(self.delaylist[k] / srate)  # delay is in msec
            stend = ststart + stdur
            try:
                plt.plot(self.a_t[k][ststart:stend] - self.a_t[k][ststart],
                         self.fa_ch1[k][ststart:stend], pline, label=label)
            except Exception:
                pass  # best effort: a trace too short for the window stays blank

        plt.figure(3)
        plt.clf()
        RAaxes = plt.subplot(2, 2, 3)
        SAaxes = plt.subplot(2, 2, 2)
        SPaxes = plt.subplot(2, 2, 4)
        sum_nogap = numpy.zeros(stdur)
        sum_gap = numpy.zeros(stdur)
        N_gap = 0
        N_nogap = 0
        # time axis of the averages in msec (srate is msec/point); exactly
        # stdur points, unlike a floating-point arange
        tb = numpy.arange(stdur) * srate
        NTrials = int(self.paramdict['Trials'])
        self.Startle_Analyze(ntrials=NTrials)  # forces init of variables...
        self.SpecMax = 0
        for i in range(ds):
            if len(self.a_t[i]) <= 0:
                break
            ststart = int(self.delaylist[i] / srate)  # delay is in msec
            stend = ststart + stdur
            if i < nhab:
                continue
            segment = self.fa_ch1[i][ststart:stend]
            self.Response_Analysis(signal=segment,
                                   rate=srate,
                                   ResponseAxes=RAaxes,
                                   SignalAxes=SAaxes,
                                   SpecAxes=SPaxes,
                                   trialcounter=i,
                                   ntrials=NTrials,
                                   gaplist=self.gapmode)
            if len(segment) != stdur:
                continue  # window runs past the end of this record
            if self.gapmode[i]:
                sum_gap = sum_gap + segment
                N_gap += 1
            else:
                sum_nogap = sum_nogap + segment
                N_nogap += 1

        plt.subplot(2, 2, 1)
        if N_nogap > 0:
            plt.plot(tb, sum_nogap / float(N_nogap), 'b-', label="No PrePulse")
        plt.title('Average Startle')
        plt.hold(True)
        if N_gap > 0:
            plt.plot(tb, sum_gap / float(N_gap), 'r-', label="PrePulse")
        plt.legend()
        plt.draw()
        plt.show()
        

# response analysis for a single trace...

    def Response_Analysis(self, timebase = None, signal = None,
                               rate = None, delay=0, SpecAxes = None,
                               SignalAxes = None, ResponseAxes = None,
                               trialcounter = 0,
                               ntrials = 1,
                               gaplist = None):
        """Plot and score the startle response of a single trace.

        Parameters
        ----------
        timebase : array, optional
            Time of each sample of ``signal``. If None it is built as
            ``index * rate``.
        signal : sequence of float
            The response trace.
        rate : float
            Sample interval in msec/point. ``rate/1000`` (seconds) is the
            spectrum clock, giving a frequency axis in Hz.
        delay : float
            Added to ``self.Analysis_Start``/``self.Analysis_End`` to select
            the samples drawn in ``SignalAxes``. NOTE(review): these limits
            are compared with ``timebase`` directly, so both must share
            units -- confirm for each caller.
        SpecAxes, SignalAxes, ResponseAxes : matplotlib axes or None
            Panels for the power spectrum (only if ``self.ShowSpectrum``),
            the analysis-window trace and the running startle magnitudes.
            Panels given as None are skipped.
        trialcounter, ntrials : int
            Passed to ``Startle_Analyze``.
        gaplist : sequence of bool, optional
            Per-trial gap/prepulse flags. They pick the trace colour (red =
            gap, blue = no gap) and are passed to ``Startle_Analyze``.
            Defaults to ``self.gapmode``.

        Returns
        -------
        float
            The d' computed by ``Startle_Analyze`` for this trial.
        """
        if gaplist is None:
            gaplist = self.gapmode
        if signal is not None:
            signal = numpy.asarray(signal)  # allow index-list selection below
        if gaplist[trialcounter]:
            pline = 'r-'  # with prepulse, red
        else:
            pline = 'b-'

        if self.ShowSpectrum and SpecAxes is not None:
            plt.axes(SpecAxes)
            # rate is msec/point; /1000 gives seconds/point -> axis in Hz
            (Rspectrum, Rfreqs) = self.pSpectrum(signal, float(rate / 1000.0))
            maxFreq = 1000.0  # upper limit of the displayed spectrum
            plt.plot(Rfreqs, Rspectrum, pline)
            if max(Rspectrum) > self.SpecMax:
                self.SpecMax = max(Rspectrum)
            plt.axis([0, maxFreq, 0, self.SpecMax])

        if timebase is None:  # 'is': '==' on an ndarray is elementwise
            timebase = numpy.arange(0, len(signal)) * rate
        timebase = numpy.asarray(timebase)
        ana_windowstart = (delay + self.Analysis_Start)
        ana_windowend = (delay + self.Analysis_End)
        apts = self.getSelectionIndices(timebase, ana_windowstart, ana_windowend)
        if SignalAxes is not None and len(apts) > 0:
            # shift so the plotted window starts at Analysis_Start/1000
            t0 = timebase[apts[0]] - self.Analysis_Start / 1000.0
            plt.axes(SignalAxes)
            plt.plot(timebase[apts] - t0, 1000.0 * signal[apts], pline)
        # use the caller's gap list (previously self.gapmode was always used)
        dprime = self.Startle_Analyze(timebase, signal=signal, startdelay=0.0,
                                      trialcounter=trialcounter,
                                      ntrials=ntrials,
                                      gaplist=gaplist)
        if ResponseAxes is not None:
            plt.axes(ResponseAxes)
            plt.plot(self.Gap_StartleMagnitude, 'ro-')
            plt.hold(True)
            plt.plot(self.noGap_StartleMagnitude, 'bx-')
            plt.hold(False)
            plt.title("d' = %7.3f" % (dprime))
        return dprime

    def Startle_Analyze(self, timebase = None,
                        signal = None,
                        startdelay = 0,
                        trialcounter = 0,
                        ntrials = 1,
                        gaplist = None):
        """Accumulate one trial's startle magnitude and return the running d'.

        With ``trialcounter == 0`` the per-condition statistics on ``self``
        are (re)initialized and 0.0 is returned without analyzing
        ``signal``. The statistics are: magnitude arrays with ``ntrials``
        slots each, counters, means and standard deviations.

        With ``trialcounter > 0``:

        * the magnitude is the sum of squares of ``signal`` over the samples
          whose ``timebase`` lies in
          [startdelay/1000, (startdelay + self.Analysis_End)/1000];
        * it is stored under the gap or no-gap condition according to
          ``gaplist[trialcounter]``;
        * that condition's mean and (population) standard deviation are
          updated to include it.

        Trials beyond the ``ntrials`` slots are ignored.

        Returns
        -------
        float
            d' = (noGap_mean - Gap_mean) / sqrt(noGap_std**2 + Gap_std**2),
            or 0.0 while either standard deviation is still zero.
        """
        self.readAnalysisTab()
        dprime = 0.0
        if trialcounter == 0:  # initialize the trials.
            self.Gap_mean = 0.0
            self.Gap_std = 0.0
            self.noGap_mean = 0.0
            self.noGap_std = 0.0
            self.Gap_StartleMagnitude = numpy.zeros(ntrials)
            self.Gap_Counter = 0
            self.noGap_StartleMagnitude = numpy.zeros(ntrials)
            self.noGap_Counter = 0
            return(dprime)

        apts = self.getSelectionIndices(timebase, startdelay / 1000.0,
                                        (startdelay + self.Analysis_End) / 1000.0)

        if trialcounter > 0:  # once we are past the habituation phase
            if gaplist[trialcounter]:
                try:
                    self.Gap_StartleMagnitude[self.Gap_Counter] = numpy.sum(
                        numpy.asarray(signal)[apts] ** 2)
                except IndexError:
                    pass  # more trials than slots (or no usable signal): skip
                else:
                    # count first so the statistics include this trial
                    self.Gap_Counter = self.Gap_Counter + 1
                    done = self.Gap_StartleMagnitude[0:self.Gap_Counter]
                    self.Gap_mean = numpy.mean(done)
                    if self.Gap_Counter >= 2:
                        self.Gap_std = numpy.std(done)
            else:
                try:
                    self.noGap_StartleMagnitude[self.noGap_Counter] = numpy.sum(
                        numpy.asarray(signal)[apts] ** 2)
                except IndexError:
                    pass  # more trials than slots (or no usable signal): skip
                else:
                    self.noGap_Counter = self.noGap_Counter + 1
                    done = self.noGap_StartleMagnitude[0:self.noGap_Counter]
                    self.noGap_mean = numpy.mean(done)
                    if self.noGap_Counter >= 2:
                        self.noGap_std = numpy.std(done)
            # now calculate the d'
            if self.noGap_std != 0 and self.Gap_std != 0:
                dprime = ((self.noGap_mean - self.Gap_mean) /
                          numpy.sqrt(self.noGap_std ** 2 + self.Gap_std ** 2))
        return(dprime)
    
################################################################################
#
# Read the initialization file. This is a simple text file with defined fields
# and numeric arguments. 
# The elements of the file can be in any order.
# unrecognized tags are ignored at your peril.
################################################################################

    def readini(self, filename= None):
        
        if filename == None:
            fd = QtGui.QFileDialog(self)
            self.fileName = str(fd.getOpenFileName(self, "Get Parameter File", "",
                                                   "Parameter Files (*.ini)"))
        else:
            self.fileName = filename
 
        rxCurrentTab = re.compile("^\[CurrTab\] ([0-9]*)")
        rxDate = re.compile("^\[Date\] ([0-9a-zA-Z\-\/.,:]*)")
        rxDescription = re.compile("^\[Description\] ([0-9a-zA-Z.,]*)")
        rxStimEnable = re.compile("(^\[StimEnable\]) ([TrueFalse]*)")
        rxWavePlot = re.compile("(^\[WavePlot\]) ([TrueFalse]*)")
        rxSpecPlot = re.compile("(^\[SpecPlot\]) ([TrueFalse]*)")
        rxCN_Level = re.compile("(^\[CN_Level\]) ([0-9.]*)")
        rxCN_Dur = re.compile("(^\[CN_Dur\]) ([0-9.]*)")
        rxCN_Var = re.compile("(^\[CN_Var\]) ([0-9.]*)")
        rxCN_Mode = re.compile("(^\[CN_Mode\]) ([0-9]*)")
        rxPP_Level = re.compile("(^\[PP_Level\]) ([0-9.]*)")
        rxPP_OffLevel= re.compile("(^\[PP_OffLevel\]) ([0-9.]*)")
        rxPP_Dur = re.compile("(^\[PP_Dur\]) ([0-9.]*)")
        rxPP_Freq = re.compile("(^\[PP_Freq\]) ([0-9.]*)")
        rxPP_HP = re.compile("(^\[PP_HP\]) ([0-9.]*)")
        rxPP_LP = re.compile("(^\[PP_LP\]) ([0-9.]*)")
        rxPP_Mode = re.compile("(^\[PP_Mode\]) ([0-9]*)")
        rxPP_NotchF1 = re.compile("(^\[PP_Notch_F1\]) ([0-9.]*)")
        rxPP_NotchF2 = re.compile("(^\[PP_Notch_F2\]) ([0-9.]*)")
        rxPP_GapFlag = re.compile("(^\[PP_GapFlag\]) ([TrueFalse]*)")
        rxPP_MultiF = re.compile("(^\[PP_MultiFreq\]) ([0-9a-zA-Z.,]*)")
        rxST_Dur = re.compile("(^\[ST_Dur\]) ([0-9.]*)")
        rxST_Level = re.compile("(^\[ST_Level\]) ([0-9.]*)")
        rxITI = re.compile("(^\[ITI_mean\]) ([0-9.]*)")
        rxITI_Var = re.compile("(^\[ITI_Var\]) ([0-9.]*)")
        rxTrials = re.compile("(^\[Trials\]) ([0-9.]*)")
        rxNHabTrials = re.compile("(^\[NHabTrials\]) ([0-9.]*)")
        try:
            hstat = open(self.fileName,"r")
            (p, f)  = os.path.split(self.fileName)
            self.setMainWindow(f)
        except IOError:
            self.Status( "%s not found" % (self.fileName))
            return
        try:
            for line in hstat:
                if rxCurrentTab.match(line): # only at start of line
                    mo = rxCurrentTab.search(line)
                    self.CurrentTab = int(mo.group(1))
                    self.setCurrentTab(self.CurrentTab)
                    continue
                if rxDate.match(line): # only at start of line
                    mo = rxDate.search(line)
                    self.fileDate = str(mo.group(1))
                    print ('Date: self.fileDate: %s ' % self.fileDate)
                    self.setMainWindow(f + " " + self.fileDate)
                    continue
                if rxStimEnable.match(line): # at start of line
                    mo = rxStimEnable.search(line) # search whole line to get decimal number
                    self.StimEnable = (mo.group(2) == "True")
                    self.ui.Stimulus_Enable.setChecked(self.StimEnable)
                    continue
                if rxWavePlot.match(line): # at start of line
                    mo = rxWavePlot.search(line) # search whole line to get decimal number
                    self.WavePlot = (mo.group(2) == "True")
                    self.ui.Waveform_PlotFlag.setChecked(self.WavePlot)
                    continue
                if rxSpecPlot.match(line): # at start of line
                    mo = rxSpecPlot.search(line) # search whole line to get decimal number
                    self.ShowSpectrum = (mo.group(2) == "True")
                    self.ui.OnlineSpectrum_Flag.setChecked(self.ShowSpectrum)
                    continue
                if rxCN_Level.match(line):
                    mo = rxCN_Level.search(line)
                    self.ui.Condition_Level.setValue(float(mo.group(2)))
                    continue
                if rxCN_Dur.match(line):
                    mo = rxCN_Dur.search(line)
                    self.ui.Condition_Dur.setValue(float(mo.group(2)))
                    continue
                if rxCN_Var.match(line):
                    mo = rxCN_Var.search(line)
                    self.ui.Condition_Var.setValue(float(mo.group(2)))
                    continue
                if rxCN_Mode.match(line):
                    mo = rxCN_Mode.search(line)
                    self.ui.Waveform_Conditioning.setCurrentIndex(int(mo.group(2)))
                    continue
                if rxPP_Level.match(line):
                    mo = rxPP_Level.search(line)
                    self.ui.PrePulse_Level.setValue(float(mo.group(2)))
                    continue
                if rxPP_Dur.match(line): # at start of line
                    mo = rxPP_Dur.search(line) 
                    self.ui.PrePulse_Dur.setValue(float(mo.group(2)))
                    continue
                if rxPP_Freq.match(line): # at start of line
                    mo = rxPP_Freq.search(line) # search whole line to get decimal number
                    self.ui.PrePulse_Freq.setValue(float(mo.group(2)))
                    continue
                if rxPP_HP.match(line): # at start of line
                    mo = rxPP_HP.search(line) # search whole line to get decimal number
                    self.ui.PrePulse_HP.setValue(float(mo.group(2)))
                    continue
                if rxPP_LP.match(line): # at start of line
                    mo = rxPP_LP.search(line) # search whole line to get decimal number
                    self.ui.PrePulse_LP.setValue(float(mo.group(2)))
                    continue
                if rxPP_NotchF1.match(line): # at start of line
                    mo = rxPP_NotchF1.search(line) # search whole line to get decimal number
                    self.ui.PrePulse_Notch_F1.setValue(float(mo.group(2)))
                    continue
                if rxPP_NotchF2.match(line): # at start of line
                    mo = rxPP_NotchF2.search(line) # search whole line to get decimal number
                    self.ui.PrePulse_Notch_F2.setValue(float(mo.group(2)))
                    continue
                if rxPP_GapFlag.match(line): # at start of line
                    self.PP_GapFlag = (mo.group(2) == "True")
                    self.ui.PrePulse_GapFlag.setChecked(self.PP_GapFlag)
                    continue

                if rxPP_Mode.match(line): # at start of line
                    mo = rxPP_Mode.search(line) # search whole line to get decimal number
                    self.ui.Waveform_PrePulse.setCurrentIndex(int(mo.group(2)))
                    continue
                if rxST_Dur.match(line): # at start of line
                    mo = rxST_Dur.search(line) # search whole line to get decimal number
                    self.ui.Startle_Dur.setValue(float(mo.group(2)))
                    continue
                if rxST_Level.match(line): # at start of line
                    mo = rxST_Level.search(line) # search whole line to get decimal number
                    self.ui.Startle_Level.setValue(float(mo.group(2)))
                    continue
                if rxITI.match(line): # at start of line
                    mo = rxITI.search(line) # search whole line to get decimal number
                    self.ui.PrePulse_ITI.setValue(float(mo.group(2)))
                    continue
                if rxITI_Var.match(line): # at start of line
                    mo = rxITI_Var.search(line) # search whole line to get decimal number
                    self.ui.PrePulse_ITI_Var.setValue(float(mo.group(2)))
                    continue
                if rxTrials.match(line): # at start of line
                    mo = rxTrials.search(line) # search whole line to get decimal number
                    self.ui.PrePulse_Trials.setValue(float(mo.group(2)))
                    continue
                if rxNHabTrials.match(line): # at start of line
                    mo = rxNHabTrials.search(line) # search whole line to get decimal number
                    self.ui.PrePulse_NHabTrials.setValue(float(mo.group(2)))
                    continue
        finally:
            hstat.close()

# Write a parameter file in the "[Key] value" format that readini (above) reads back.
            
    def writeini(self, filename=None):
        """Save the current stimulus/protocol parameters to a text file.

        Each parameter is written as a "[Key] value" line, the format that
        readini parses, so the settings can be reloaded later.

        filename: path to write. If None, the user is asked for a path with a
            save dialog; if that dialog is cancelled, nothing is written and
            self.fileName is left unchanged.

        Side effects: sets self.fileName to the path written, and refreshes
        the parameter attributes from the GUI via self.readParameters().
        Exceptions from readParameters() or from writing the file propagate to
        the caller.
        """
        if filename is None:
            filename = str(QtGui.QFileDialog.getSaveFileName(self))
            if not filename:
                return  # dialog cancelled: write nothing
        self.fileName = filename
        # Pull the latest values from the GUI *before* opening the file, so a
        # failure here cannot leave a truncated (empty) parameter file behind.
        self.readParameters()
        hstat = open(self.fileName, "w")
        try:
            # The date is stored so a reloaded file shows when it was saved.
            hstat.write("[Description] %s\n[Date] %s\n[CurrTab] %d\n" %
                        (self.Description, strftime("%d-%b-%Y"), self.getCurrentTab()))
            hstat.write("[StimEnable] %s\n[WavePlot] %s\n[SpecPlot] %s\n" %
                        (self.StimEnable, self.WavePlot, self.ShowSpectrum))
            hstat.write("[CN_Level] %f\n[CN_Dur] %f\n[CN_Var] %f\n[CN_Mode] %d\n"
                        % (self.CN_Level, self.CN_Dur, self.CN_Var, self.CN_Mode))
            hstat.write("[PP_Level] %f \n[PP_OffLevel] %f \n[PP_Dur] %f\n[PP_Mode] %d\n"
                        % (self.PP_Level, self.PP_OffLevel, self.PP_Dur, self.PP_Mode))
            hstat.write("[PP_Freq] %f \n[PP_HP] %f\n[PP_LP] %f\n"
                        % (self.PP_Freq, self.PP_HP, self.PP_LP))
            hstat.write("[ST_Dur] %f \n[ST_Level] %f\n"
                        % (self.ST_Dur, self.ST_Level))
            hstat.write("[ITI_mean] %f \n[ITI_Var] %f \n[Trials] %f \n[NHabTrials] %f\n"
                        % (self.ITI, self.ITI_Var, self.Trials, self.NHabTrials))
            hstat.write("[PP_NotchF1] %f \n[PP_NotchF2] %f \n[PP_GapFlag] %s \n[PP_MultiFreq] %s\n"
                        % (self.PP_Notch_F1, self.PP_Notch_F2, self.PP_GapFlag, self.PP_MultiFreq))
        finally:
            # Always release the file handle, even if a write fails.
            hstat.close()



################################################################################
#
# main entry
#

if __name__ == "__main__":
    # Create the Qt application object, then build and display the startle
    # main window before entering the event loop.
    app = QtGui.QApplication(sys.argv)
    MainWindow = PyStartle()
    MainWindow.show()
    # The event loop's return code becomes the process exit status.
    status = app.exec_()
    sys.exit(status)
    